Rikiya Yamashita
Inspired by this Washington Post article by Harry Stevens
Referred to this implementation of elastic collisions by Christian Hill.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
from matplotlib import animation
from matplotlib.gridspec import GridSpec
from itertools import combinations
from IPython.display import HTML
%matplotlib inline
import matplotlib
from matplotlib import colors as mcolors

# Raise the size limit for animations embedded in HTML output to 256
# (matplotlib documents this rcParam as a size in MB).
matplotlib.rcParams['animation.embed_limit'] = 256

# Single name -> colour lookup: the base colour codes, overlaid with the
# CSS4 named colours (CSS4 wins if a name appears in both).
colors = {**mcolors.BASE_COLORS, **mcolors.CSS4_COLORS}
class Particle:
    """A disc with a position, a velocity and a health state.

    Attributes:
        r: Position as a float NumPy array ``[x, y]``.
        v: Velocity as a float NumPy array ``[vx, vy]``.
        radius: Radius of the disc.
        mass: Taken as ``radius ** 2``.
        prop: Health state code -- 0: healthy, 1: infected, 2: recovered,
            3: dead.
        prop_dict: Maps each health state code to a colour name.
        styles: Keyword arguments handed to the ``Circle`` patch in
            :meth:`draw`; holds the face colour of the *initial* state.
        timer: Counter starting at 0. This class never changes it; it is
            there for callers to update.
    """

    def __init__(self, x, y, vx, vy, radius=0.01, prop=0):
        """Create a particle at ``(x, y)`` moving with velocity ``(vx, vy)``.

        Args:
            x, y: Initial position.
            vx, vy: Initial velocity.
            radius: Disc radius.
            prop: Initial health state code (see the class docstring).
        """
        # Force float storage: with integer inputs NumPy would otherwise
        # pick an integer dtype, which makes the in-place update in
        # advance() fail and makes the x/y/vx/vy setters truncate.
        self.r = np.array((x, y), dtype=float)
        self.v = np.array((vx, vy), dtype=float)
        self.radius = radius
        self.mass = self.radius**2
        self.prop = prop  # 0: healthy, 1: infected, 2: recovered, 3: dead
        self.prop_dict = {0: 'silver', 1: 'indianred', 2: 'cadetblue', 3: 'dimgray'}
        self.styles = {'facecolor': self.prop_dict[self.prop]}
        self.timer = 0

    @property
    def x(self):
        """x coordinate of the position (``r[0]``)."""
        return self.r[0]

    @x.setter
    def x(self, value):
        self.r[0] = value

    @property
    def y(self):
        """y coordinate of the position (``r[1]``)."""
        return self.r[1]

    @y.setter
    def y(self, value):
        self.r[1] = value

    @property
    def vx(self):
        """x component of the velocity (``v[0]``)."""
        return self.v[0]

    @vx.setter
    def vx(self, value):
        self.v[0] = value

    @property
    def vy(self):
        """y component of the velocity (``v[1]``)."""
        return self.v[1]

    @vy.setter
    def vy(self, value):
        self.v[1] = value

    def overlaps(self, other):
        """Return True if this disc and ``other`` intersect.

        Two discs intersect when the distance between their centres is
        strictly smaller than the sum of their radii.
        """
        return np.hypot(*(self.r - other.r)) < self.radius + other.radius

    def draw(self, ax):
        """Add a ``Circle`` patch for this particle to ``ax`` and return it."""
        circle = Circle(xy=self.r, radius=self.radius, **self.styles)
        ax.add_patch(circle)
        return circle

    def advance(self, dt):
        """Move the particle along its velocity for a time step ``dt``.

        The position array is updated in place.
        """
        self.r += self.v * dt
class Simulation:
    """Animated toy epidemic: discs bounce in the unit square and pass an
    infection to each other on contact.

    Every particle carries a health state code in ``prop``
    (0: healthy, 1: infected, 2: recovered, 3: dead). The first particle
    placed starts out infected. An infected particle that has been infected
    for more than ``disease_duration`` frames either recovers or dies.

    The figure has three axes: ``ax1`` shows the particles, ``ax2`` the
    current counts as text and ``ax3`` a stacked bar history of the four
    population fractions.
    """

    ParticleClass = Particle

    # Health state codes stored in ``Particle.prop``.
    HEALTHY, INFECTED, RECOVERED, DEAD = 0, 1, 2, 3

    # Number of random positions tried for one particle before
    # init_particles() gives up instead of looping forever.
    max_placement_attempts = 100000

    def __init__(self, n, radius=0.05, prop=0, transmission_rate=0.8,
                 disease_duration=40, death_rate=0.2, dt=0.04, dttype=None):
        """Create the particles and the bookkeeping for the statistics.

        Args:
            n: Number of particles.
            radius: One radius for all particles, or an iterable with
                exactly ``n`` radii.
            prop: Initial health state of every particle except the first,
                which always starts infected.
            transmission_rate: Probability that a healthy particle becomes
                infected on a frame in which it overlaps an infected one.
            disease_duration: Number of frames a particle stays infected
                before it recovers or dies.
            death_rate: Probability that the infection ends in death
                rather than recovery.
            dt: Time step applied per animation frame.
            dttype: ``None`` to keep ``dt`` fixed, or one of ``'stop1/4'``,
                ``'stop1/2'``, ``'stop3/4'``, ``'lightswitch'`` to let
                :meth:`animate` change ``dt`` at fixed frame numbers.
        """
        # Artists are only created by init(); start with empty lists so the
        # simulation can also be stepped before anything has been drawn.
        self.circles = []
        self.texts = []
        self.bar_patches = []
        self._bars_drawn = 0

        self.init_particles(n, radius, prop)
        self.dt = dt
        self.dttype = dttype
        self.transmission_rate = transmission_rate
        self.disease_duration = disease_duration
        self.death_rate = death_rate
        self.nparticles = n

        # History of population fractions, one entry per recorded step.
        # The first entry describes the actual initial population (which
        # already contains one infected particle).
        healthy, infected, recovered, dead = self._count_states()
        total = max(self.nparticles, 1)
        self.healthy = np.array([healthy / total])
        self.infected = np.array([infected / total])
        self.recovered = np.array([recovered / total])
        self.dead = np.array([dead / total])

    def place_particle(self, rad, prop):
        """Try to add one particle at a random position.

        The position is drawn uniformly so that the disc lies inside the
        unit square; the speed is drawn from [0.05, 0.15) with a uniformly
        random direction.

        Returns:
            True if the particle was added, False if the random position
            overlapped an existing particle (nothing is added then).
        """
        x, y = rad + (1 - 2*rad) * np.random.random(2)
        vr = 0.1 * np.sqrt(np.random.random()) + 0.05
        vphi = 2*np.pi * np.random.random()
        vx, vy = vr * np.cos(vphi), vr * np.sin(vphi)
        particle = self.ParticleClass(x, y, vx, vy, rad, prop)
        if any(other.overlaps(particle) for other in self.particles):
            return False
        self.particles.append(particle)
        return True

    def init_particles(self, n, radius, prop):
        """Create ``n`` non-overlapping particles.

        Args:
            n: Number of particles.
            radius: A single radius, or an iterable of exactly ``n`` radii.
            prop: Health state for every particle but the first; the first
                particle is created infected.

        Raises:
            ValueError: ``radius`` is an iterable whose length is not ``n``.
            RuntimeError: a particle could not be placed without overlap
                within ``max_placement_attempts`` tries.
        """
        try:
            radii = list(radius)
        except TypeError:
            # Not iterable: one radius shared by all particles.
            radii = [radius] * n
        else:
            if len(radii) != n:
                raise ValueError(
                    'expected %d radii, got %d' % (n, len(radii)))

        self.n = n
        self.particles = []
        for i, rad in enumerate(radii):
            state = self.INFECTED if i == 0 else prop
            for _ in range(self.max_placement_attempts):
                if self.place_particle(rad, state):
                    break
            else:
                raise RuntimeError(
                    'could not place particle %d without overlap after %d '
                    'attempts' % (i, self.max_placement_attempts))

    def change_velocities(self, p1, p2):
        """Apply an elastic collision between ``p1`` and ``p2``.

        The velocities are left untouched when

        * any velocity component of either particle is exactly zero (this
          is the case for dead particles, whose velocity is set to zero),
        * the two centres coincide (the collision normal is undefined), or
        * the particles are not approaching each other. Without this check
          a pair that still overlaps on the next frame would be reflected
          again and end up oscillating in place.
        """
        if p1.vx * p1.vy * p2.vx * p2.vy == 0:
            return
        m1, m2 = p1.mass, p2.mass
        M = m1 + m2
        dr = p1.r - p2.r
        d = np.dot(dr, dr)  # squared distance between the centres
        if d == 0:
            return
        v1, v2 = p1.v, p2.v
        approach = np.dot(v1 - v2, dr)
        if approach >= 0:
            return
        p1.v = v1 - 2*m2 / M * approach / d * dr
        p2.v = v2 + 2*m1 / M * approach / d * dr

    def handle_collisions(self):
        """Resolve all pairwise overlaps: bounce and possibly infect.

        Dead particles are ignored. For every overlapping pair the
        velocities are updated, and if exactly one of the two is infected
        and the other healthy, the healthy one becomes infected with
        probability ``transmission_rate``.
        """
        for i, j in combinations(range(len(self.particles)), 2):
            first, second = self.particles[i], self.particles[j]
            if first.prop == self.DEAD or second.prop == self.DEAD:
                continue
            if not first.overlaps(second):
                continue
            self.change_velocities(first, second)
            if first.prop == self.INFECTED and second.prop == self.HEALTHY:
                self._expose(j)
            elif first.prop == self.HEALTHY and second.prop == self.INFECTED:
                self._expose(i)

    def handle_boundary_collisions(self, p):
        """Bounce ``p`` off the walls of the unit square.

        A particle sticking out over a wall is moved back so that it just
        touches the wall, and the matching velocity component is negated.
        """
        if p.x - p.radius < 0:
            p.x = p.radius
            p.vx = -p.vx
        if p.x + p.radius > 1:
            p.x = 1 - p.radius
            p.vx = -p.vx
        if p.y - p.radius < 0:
            p.y = p.radius
            p.vy = -p.vy
        if p.y + p.radius > 1:
            p.y = 1 - p.radius
            p.vy = -p.vy

    def advance_animation(self):
        """Advance the simulation by one time step ``dt``.

        Dead particles are left where they were parked: moving them through
        the boundary handling would clamp them back into the visible box.

        Returns:
            The list of circle patches (empty before :meth:`init` ran).
        """
        for i, p in enumerate(self.particles):
            if p.prop == self.DEAD:
                continue
            p.advance(self.dt)
            self.handle_boundary_collisions(p)
            if i < len(self.circles):
                self.circles[i].center = p.r
        self.handle_collisions()
        return self.circles

    def init(self):
        """Create the artists; used as ``init_func`` of the animation.

        Artists left over from an earlier call are removed first, so
        calling this more than once does not stack duplicate circles,
        texts or bars on the axes.

        Returns:
            All artists created here, as one list.
        """
        self._discard_artists()
        self.circles = [particle.draw(self.ax1) for particle in self.particles]

        # One row per state: a black label on the left and the count,
        # drawn in the colour of that state, on the right.
        rows = (
            (0.725, 'healthy', 'silver'),
            (0.575, 'infected', 'indianred'),
            (0.425, 'recovered', 'cadetblue'),
            (0.275, 'dead', 'dimgray'),
        )
        self.texts = []
        for height, label, colour in rows:
            self.texts.append(self.ax2.text(0.25, height, label, ha="left", va="center", fontsize=14, color='black'))
            self.texts.append(self.ax2.text(0.75, height, '', ha="right", va="center", fontsize=14, color=colour))

        self.bar = self.ax3.bar([], [], width=1)
        return self.circles + self.texts + list(self.bar)

    def animate(self, i):
        """Compute and draw frame ``i``; used as the animation function.

        Steps: adjust ``dt`` according to ``dttype``, move the particles,
        progress the disease of infected particles, append the new
        population fractions to the history and update texts and bars.

        Returns:
            All circles, texts and bar patches, as one list.
        """
        self._update_dt(i)
        self.advance_animation()

        for idx, p in enumerate(self.particles):
            if p.prop != self.INFECTED:
                continue
            p.timer += 1
            if p.timer <= self.disease_duration:
                continue
            # The infection has run its course: recover or die.
            p.prop = np.random.choice(
                [self.RECOVERED, self.DEAD],
                p=[1-self.death_rate, self.death_rate])
            self._recolor(idx)
            if p.prop == self.DEAD:
                # Park the particle outside the visible unit square and
                # stop it.
                p.x, p.y = 10 + 10 * np.random.random(2)
                p.vx, p.vy = (0, 0)
                if idx < len(self.circles):
                    self.circles[idx].center = p.r

        healthy, infected, recovered, dead = self._count_states()
        total = max(self.nparticles, 1)
        self.healthy = np.append(self.healthy, healthy / total)
        self.infected = np.append(self.infected, infected / total)
        self.recovered = np.append(self.recovered, recovered / total)
        self.dead = np.append(self.dead, dead / total)

        self.texts[1].set_text(str(healthy))
        self.texts[3].set_text(str(infected))
        self.texts[5].set_text(str(recovered))
        self.texts[7].set_text(str(dead))

        self._draw_new_bars()
        return self.circles + self.texts + self.bar_patches

    def setup_animation(self):
        """Create the figure and its three axes.

        Requires ``self.nframes`` to be set; it fixes the x range of the
        history plot.
        """
        self.fig = plt.figure(figsize=(10, 5))
        gs = GridSpec(2, 2, figure=self.fig)
        self.ax1 = self.fig.add_subplot(gs[:, 0])
        self.ax2 = self.fig.add_subplot(gs[0, 1])
        self.ax3 = self.fig.add_subplot(gs[1, 1])

        # Particle box: framed unit square without ticks.
        for s in ['top', 'bottom', 'left', 'right']:
            self.ax1.spines[s].set_linewidth(1)
        self.ax1.set_xlim(0, 1)
        self.ax1.set_ylim(0, 1)
        self.ax1.xaxis.set_ticks([])
        self.ax1.yaxis.set_ticks([])

        # Text panel: no frame at all.
        self.ax2.axis("off")

        # History plot: one column per frame, fractions on the y axis.
        self.ax3.spines['right'].set_visible(False)
        self.ax3.spines['top'].set_visible(False)
        self.ax3.set_xlim(0, self.nframes)
        self.ax3.set_ylim(0, 1)
        self.ax3.xaxis.set_ticks([])
        self.ax3.yaxis.set_ticks([])

    def save_or_show_animation(self, anim, save, filename='flattenthecurve_jupyter.gif'):
        """Write ``anim`` to ``filename`` if ``save`` is true, else show it.

        Saving uses the 'imagemagick' writer at 60 fps.
        """
        if save:
            anim.save(filename, writer='imagemagick', fps=60)
        else:
            plt.show()

    def do_animation(self, save=False, nframes=1000, interval=1000/60, filename='flattenthecurve_jupyter.gif'):
        """Set up the figure, build the animation and save or show it.

        Args:
            save: Save to ``filename`` instead of showing the figure.
            nframes: Number of animation frames.
            interval: Delay between frames, passed to ``FuncAnimation``.
            filename: Output file used when ``save`` is true.

        Returns:
            The ``FuncAnimation`` object.
        """
        self.nframes = nframes
        self.setup_animation()
        anim = animation.FuncAnimation(self.fig, self.animate,
                                       init_func=self.init, frames=nframes,
                                       interval=interval, blit=True)
        self.save_or_show_animation(anim, save, filename)
        return anim

    def _count_states(self):
        """Return the counts ``[healthy, infected, recovered, dead]``."""
        counts = [0, 0, 0, 0]
        for p in self.particles:
            if p.prop in (self.HEALTHY, self.INFECTED, self.RECOVERED, self.DEAD):
                counts[p.prop] += 1
        return counts

    def _recolor(self, idx):
        """Give circle ``idx`` the colour of its particle's current state.

        Does nothing while the circles have not been drawn yet.
        """
        if idx < len(self.circles):
            p = self.particles[idx]
            self.circles[idx].set_color(p.prop_dict[p.prop])

    def _expose(self, idx):
        """Infect healthy particle ``idx`` with probability ``transmission_rate``."""
        self.particles[idx].prop = np.random.choice(
            [self.HEALTHY, self.INFECTED],
            p=[1-self.transmission_rate, self.transmission_rate])
        self._recolor(idx)

    def _update_dt(self, i):
        """Apply the ``dttype`` schedule for frame ``i``.

        The 'stop...' modes switch to ``dt = 0.12`` from frame 30, 60 or 90
        on. 'lightswitch' uses 0.12 for frames 30-59, 0.04 for frames 60-89
        and 0.12 from frame 90 on. In all modes ``dt`` is left unchanged on
        frames before the first switch.
        """
        release_frame = {'stop1/4': 30, 'stop1/2': 60, 'stop3/4': 90}
        if self.dttype == 'lightswitch':
            if 30 <= i < 60 or i >= 90:
                self.dt = 0.12
            elif 60 <= i < 90:
                self.dt = 0.04
        elif self.dttype in release_frame:
            if i >= release_frame[self.dttype]:
                self.dt = 0.12

    def _draw_new_bars(self):
        """Add stacked bars for the history entries not drawn so far.

        Only new columns are created; bars drawn on earlier frames stay on
        the axes, so the number of artists grows by a constant per frame.
        """
        start, stop = self._bars_drawn, len(self.healthy)
        xs = list(range(start, stop))
        bottom = np.zeros(stop - start)
        # Stacking order from the bottom up.
        layers = (
            (self.infected, 'indianred'),
            (self.healthy, 'silver'),
            (self.recovered, 'cadetblue'),
            (self.dead, 'dimgray'),
        )
        for series, colour in layers:
            heights = series[start:]
            self.bar = self.ax3.bar(xs, heights, width=1, bottom=bottom, color=colors[colour])
            self.bar_patches.extend(self.bar)
            bottom = bottom + heights
        self._bars_drawn = stop

    def _discard_artists(self):
        """Remove all artists created by earlier init()/animate() calls."""
        for artist in self.circles + self.texts + self.bar_patches:
            try:
                artist.remove()
            except (NotImplementedError, ValueError):
                # The artist is no longer attached to its axes (for example
                # because the axes were cleared); nothing left to remove.
                pass
        self.circles = []
        self.texts = []
        self.bar_patches = []
        self._bars_drawn = 0
# Scenario: constant time step dt = 0.12 on every frame (dttype is None, so
# Simulation.animate never changes dt).
nparticles, radii, prop = 30, 0.05, 0
transmission_rate, disease_duration, death_rate = 0.8, 40, 0.2
nframes = 120
dt, dttype = 0.12, None  # dttype alternatives: 'stop1/4', 'stop1/2', 'stop3/4', 'lightswitch'
sim = Simulation(
    n=nparticles,
    radius=radii,
    prop=prop,
    transmission_rate=transmission_rate,
    disease_duration=disease_duration,
    death_rate=death_rate,
    dt=dt,
    dttype=dttype,
)
anim = sim.do_animation(save=False, nframes=nframes)
# Without Social Distancing
HTML(anim.to_jshtml())
# Scenario: constant time step dt = 0.04 on every frame (dttype is None, so
# Simulation.animate never changes dt).
nparticles, radii, prop = 30, 0.05, 0
transmission_rate, disease_duration, death_rate = 0.8, 40, 0.2
nframes = 120
dt, dttype = 0.04, None  # dttype alternatives: 'stop1/4', 'stop1/2', 'stop3/4', 'lightswitch'
sim = Simulation(
    n=nparticles,
    radius=radii,
    prop=prop,
    transmission_rate=transmission_rate,
    disease_duration=disease_duration,
    death_rate=death_rate,
    dt=dt,
    dttype=dttype,
)
anim = sim.do_animation(save=False, nframes=nframes)
# With Social Distancing
HTML(anim.to_jshtml())
# Scenario: start with dt = 0.04; with dttype 'stop1/4' Simulation.animate
# switches to dt = 0.12 from frame 30 on (a quarter of the 120 frames).
nparticles, radii, prop = 30, 0.05, 0
transmission_rate, disease_duration, death_rate = 0.8, 40, 0.2
nframes = 120
dt, dttype = 0.04, 'stop1/4'  # dttype alternatives: None, 'stop1/2', 'stop3/4', 'lightswitch'
sim = Simulation(
    n=nparticles,
    radius=radii,
    prop=prop,
    transmission_rate=transmission_rate,
    disease_duration=disease_duration,
    death_rate=death_rate,
    dt=dt,
    dttype=dttype,
)
anim = sim.do_animation(save=False, nframes=nframes)
# With Social Distancing for 1/4 Period
HTML(anim.to_jshtml())
# Scenario: start with dt = 0.04; with dttype 'stop1/2' Simulation.animate
# switches to dt = 0.12 from frame 60 on (half of the 120 frames).
nparticles, radii, prop = 30, 0.05, 0
transmission_rate, disease_duration, death_rate = 0.8, 40, 0.2
nframes = 120
dt, dttype = 0.04, 'stop1/2'  # dttype alternatives: None, 'stop1/4', 'stop3/4', 'lightswitch'
sim = Simulation(
    n=nparticles,
    radius=radii,
    prop=prop,
    transmission_rate=transmission_rate,
    disease_duration=disease_duration,
    death_rate=death_rate,
    dt=dt,
    dttype=dttype,
)
anim = sim.do_animation(save=False, nframes=nframes)
# With Social Distancing for 1/2 Period
HTML(anim.to_jshtml())
# Scenario: start with dt = 0.04; with dttype 'stop3/4' Simulation.animate
# switches to dt = 0.12 from frame 90 on (three quarters of the 120 frames).
nparticles, radii, prop = 30, 0.05, 0
transmission_rate, disease_duration, death_rate = 0.8, 40, 0.2
nframes = 120
dt, dttype = 0.04, 'stop3/4'  # dttype alternatives: None, 'stop1/4', 'stop1/2', 'lightswitch'
sim = Simulation(
    n=nparticles,
    radius=radii,
    prop=prop,
    transmission_rate=transmission_rate,
    disease_duration=disease_duration,
    death_rate=death_rate,
    dt=dt,
    dttype=dttype,
)
anim = sim.do_animation(save=False, nframes=nframes)
# With Social Distancing for 3/4 Period
HTML(anim.to_jshtml())
# Scenario: same 'stop3/4' schedule (dt = 0.04, then dt = 0.12 from frame 90
# on), but rendered for 240 frames instead of 120. The switch frame is fixed
# in Simulation.animate and does not scale with nframes.
nparticles, radii, prop = 30, 0.05, 0
transmission_rate, disease_duration, death_rate = 0.8, 40, 0.2
nframes = 240
dt, dttype = 0.04, 'stop3/4'  # dttype alternatives: None, 'stop1/4', 'stop1/2', 'lightswitch'
sim = Simulation(
    n=nparticles,
    radius=radii,
    prop=prop,
    transmission_rate=transmission_rate,
    disease_duration=disease_duration,
    death_rate=death_rate,
    dt=dt,
    dttype=dttype,
)
anim = sim.do_animation(save=False, nframes=nframes)
# With Social Distancing for 3/4 Period (extended time frame)
HTML(anim.to_jshtml())
# Scenario: 'lightswitch' schedule. Simulation.animate keeps the initial
# dt = 0.04 for frames 0-29, uses dt = 0.12 for frames 30-59, dt = 0.04 for
# frames 60-89 and dt = 0.12 from frame 90 on.
nparticles, radii, prop = 30, 0.05, 0
transmission_rate, disease_duration, death_rate = 0.8, 40, 0.2
nframes = 120
dt, dttype = 0.04, 'lightswitch'  # dttype alternatives: None, 'stop1/4', 'stop1/2', 'stop3/4'
sim = Simulation(
    n=nparticles,
    radius=radii,
    prop=prop,
    transmission_rate=transmission_rate,
    disease_duration=disease_duration,
    death_rate=death_rate,
    dt=dt,
    dttype=dttype,
)
anim = sim.do_animation(save=False, nframes=nframes)
# With Social Distancing: Lightswitch Method (Social Distancing: On --> Off --> On --> Off)
HTML(anim.to_jshtml())